import json
import random
import sys
import zipfile
import copy
import re
import jieba
from collections import defaultdict
import os
import pickle as pkl
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from pprint import pprint
from xbot.nlg.nlg_with_template import TemplateNLG


seed = 2019
random.seed(seed)


def split_delex_sentence(sen):
    res_sen = ""
    pattern = re.compile(r"(\[[^\[^\]]+\])")
    slots = pattern.findall(sen)
    for slot in slots:
        sen = sen.replace(slot, "[slot]")
    sen = sen.split("[slot]")
    for part in sen:
        part = " ".join(jieba.lcut(part))
        res_sen += part
        if slots:
            res_sen += " " + slots.pop(0) + " "
    return res_sen


def act2intent(dialog_act: list):
    cur_act = copy.deepcopy(dialog_act)
    if "酒店设施" in cur_act[2]:
        if cur_act[0] == "Inform":
            cur_act[2] = cur_act[2].split("-")[0] + "+" + cur_act[3]
        elif cur_act[0] == "Request":
            cur_act[2] = cur_act[2].split("-")[0]
    if cur_act[0] == "Select":
        cur_act[2] = "源领域+" + cur_act[3]
    intent = "+".join(cur_act[:-1])
    if "+".join(cur_act) == "Inform+景点+门票+免费" or cur_act[-1] == "无":
        intent = "+".join(cur_act)
    return intent


def value_replace(sentences, dialog_act):
    dialog_act = copy.deepcopy(dialog_act)
    intent_frequency = defaultdict(int)
    for act in dialog_act:
        intent = act2intent(copy.deepcopy(act))
        intent_frequency[intent] += 1
        if intent_frequency[intent] > 1:  # if multiple same intents...
            intent += str(intent_frequency[intent])

        if "酒店设施" in intent:
            try:
                sentences = sentences.replace("[" + intent + "]", act[2].split("-")[1])
                sentences = sentences.replace("[" + intent + "1]", act[2].split("-")[1])
            except Exception as e:
                print("Act causing problem in replacement:")
                pprint(act)
                raise e
        else:
            sentences = sentences.replace("[" + intent + "]", act[3])
            sentences = sentences.replace(
                "[" + intent + "1]", act[3]
            )  # if multiple same intents and this is 1st

    if "[" in sentences and "]" in sentences:
        print("\n\nValue replacement not completed!!! Current sentence: %s" % sentences)
        print("current da:")
        print(dialog_act)
        pattern = re.compile(r"(\[[^\[^\]]+\])")
        slots = pattern.findall(sentences)
        for slot in slots:
            sentences = sentences.replace(slot, " ")
        print("after replace:", sentences)
        # raise Exception('\n\nValue replacement not completed!!! Current sentence: %s' % sentences)
    return sentences


def get_bleu4(dialog_acts, golden_utts, gen_utts, data_key):
    das2utts = {}
    for das, utt, gen in zip(dialog_acts, golden_utts, gen_utts):
        intent_frequency = defaultdict(int)
        for act in das:
            cur_act = copy.copy(act)

            # intent list
            facility = None  # for 酒店设施
            if "酒店设施" in cur_act[2]:
                facility = cur_act[2].split("-")[1]
                if cur_act[0] == "Inform":
                    cur_act[2] = cur_act[2].split("-")[0] + "+" + cur_act[3]
                elif cur_act[0] == "Request":
                    cur_act[2] = cur_act[2].split("-")[0]
            if cur_act[0] == "Select":
                cur_act[2] = "源领域+" + cur_act[3]
            intent = "+".join(cur_act[:-1])
            if "+".join(cur_act) == "Inform+景点+门票+免费" or cur_act[-1] == "无":
                intent = "+".join(cur_act)

            intent_frequency[intent] += 1

            # utt content replacement
            if (
                act[0] in ["Inform", "Recommend"] or "酒店设施" in intent
            ) and not intent.endswith("无"):
                if act[3] in utt or (facility and facility in utt):
                    # value to be replaced
                    if "酒店设施" in intent:
                        value = facility
                    else:
                        value = act[3]

                    # placeholder
                    placeholder = "[" + intent + "]"
                    placeholder_one = "[" + intent + "1]"
                    placeholder_with_number = (
                        "[" + intent + str(intent_frequency[intent]) + "]"
                    )

                    if intent_frequency[intent] > 1:
                        utt = utt.replace(placeholder, placeholder_one)
                        utt = utt.replace(value, placeholder_with_number)

                        gen = gen.replace(placeholder, placeholder_one)
                        gen = gen.replace(value, placeholder_with_number)
                    else:
                        utt = utt.replace(value, placeholder)
                        gen = gen.replace(value, placeholder)

        hash_key = ""
        for act in sorted(das):
            hash_key += act2intent(act)
        das2utts.setdefault(hash_key, {"refs": [], "gens": []})
        das2utts[hash_key]["refs"].append(utt)
        das2utts[hash_key]["gens"].append({"das": das, "gen": gen})

    refs, gens = [], []
    for das in das2utts.keys():
        assert len(das2utts[das]["refs"]) == (len(das2utts[das]["gens"]))
        for gen_pair in das2utts[das]["gens"]:
            lex_das = gen_pair["das"]  # das w/ value
            gen = gen_pair["gen"]
            lex_gen = value_replace(gen, lex_das)
            gens.append([x for x in jieba.lcut(lex_gen) if x.strip()])
            refs.append(
                [
                    [x for x in jieba.lcut(value_replace(s, lex_das)) if x.strip()]
                    for s in das2utts[das]["refs"]
                ]
            )

    with open(
        os.path.join("", "generated_sens_%s.json" % data_key), "w", encoding="utf-8"
    ) as f:
        json.dump(
            {"refs": refs, "gens": gens},
            f,
            indent=4,
            sort_keys=True,
            ensure_ascii=False,
        )
    print("generated_sens_%s.txt saved!" % data_key)

    print("Start calculating bleu score...")
    bleu = corpus_bleu(
        refs,
        gens,
        weights=(0.25, 0.25, 0.25, 0.25),
        smoothing_function=SmoothingFunction().method1,
    )
    return bleu


if __name__ == "__main__":
    # if len(sys.argv) != 2:
    #     print("usage:")
    #     print("\t python evaluate.py data_key")
    #     print("\t data_key=usr/sys")
    #     sys.exit()
    # data_key = sys.argv[1]
    # assert data_key == 'usr' or data_key == 'sys'

    data_key = "usr"

    model = TemplateNLG(is_user=(data_key == "usr"), mode="auto_manual")

    ##加载测试集
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(cur_dir, "../../data/crosswoz/raw/test.json.zip")
    archive = zipfile.ZipFile(file_path, "r")
    test_data = json.load(archive.open("tests.json"))

    dialog_acts = []
    golden_utts = []
    gen_utts = []
    gen_slots = []
    intent_list = []
    dialog_acts2genutts = defaultdict(list)

    sen_num = 0
    sess_num = 0

    # if os.path.isfile('dialog_acts_%s.pkl' % data_key) and os.path.isfile(
    #         'golden_utts_%s.pkl' % data_key) and os.path.isfile('gen_utts_%s.pkl' % data_key):
    #     with open('dialog_acts_%s.pkl' % data_key, 'rb') as fda:
    #         dialog_acts = pkl.load(fda)
    #     with open('golden_utts_%s.pkl' % data_key, 'rb') as fgold:
    #         golden_utts = pkl.load(fgold)
    #     with open('gen_utts_%s.pkl' % data_key, 'rb') as fgen:
    #         gen_utts = pkl.load(fgen)
    #     for no, sess in list(test_data.items()):
    #         sess_num += 1
    #         for turn in sess['messages']:
    #             if turn['role'] == 'usr' and data_key == 'sys':
    #                 continue
    #             elif turn['role'] == 'sys' and data_key == 'usr':
    #                 continue
    #             sen_num += 1
    #     print('sen_num: ', sen_num)
    #     print('Loaded from existing data!')
    # else:
    if True:
        for no, sess in list(test_data.items()):
            sess_num += 1
            print("[%d/%d]" % (sess_num, len(test_data)))
            for turn in sess["messages"]:
                if turn["role"] == "usr" and data_key == "sys":
                    continue
                elif turn["role"] == "sys" and data_key == "usr":
                    continue
                sen_num += 1

                dialog_acts.append(turn["dialog_act"])
                golden_utts.append(turn["content"])  # slots **values**
                gen_utt = model.generate(turn["dialog_act"])
                gen_utts.append(gen_utt)  # slots **values**

        with open("dialog_acts_%s.pkl" % data_key, "wb") as fda:
            pkl.dump(dialog_acts, fda)
        with open("golden_utts_%s.pkl" % data_key, "wb") as fgold:
            pkl.dump(golden_utts, fgold)
        with open("gen_utts_%s.pkl" % data_key, "wb") as fgen:
            pkl.dump(gen_utts, fgen)

    bleu4 = get_bleu4(dialog_acts, golden_utts, gen_utts, data_key)
    print("Calculate bleu-4")
    print("BLEU-4: %.4f" % bleu4)
    print(
        "Model on {} session {} sentences data_key={}".format(
            len(test_data), sen_num, data_key
        )
    )
